from data_utils import load_mnist, load_cifar10
from models import run_all_models
from visualize import plot_results

def main() -> None:
    """Run the whole pipeline: load datasets, run the models, plot results.

    Loads the MNIST and CIFAR-10 train/test sets, prints the number of
    samples in each, passes all four sets to ``run_all_models`` and hands
    its return value to ``plot_results``.
    """
    mnist_train, mnist_test = load_mnist()
    cifar_train, cifar_test = load_cifar10()

    # Report dataset sizes; the message text is emitted in Chinese and is
    # kept verbatim because it is user-facing output.
    print(f"MNIST训练集样本数: {len(mnist_train)}，测试集样本数: {len(mnist_test)}")
    print(f"CIFAR-10训练集样本数: {len(cifar_train)}，测试集样本数: {len(cifar_test)}")

    results = run_all_models(mnist_train, mnist_test, cifar_train, cifar_test)
    plot_results(results)


# Main entry point: start the pipeline only when this file is executed as a
# script, so that importing the module does not kick off data loading.
if __name__ == "__main__":
    main()
